#!/usr/bin/env python3
"""
bell_test.py

Implementation of a CHSH Bell test using the tick‑flip algebra.  This script
either leverages the ``ar_sim`` library from ``ar-operator-core`` when
available or falls back to an analytic simulation of the maximally entangled
singlet state.  The CHSH quantity S is defined as

    S = E(a,b) - E(a,b') + E(a',b) + E(a',b'),

where each E(x,y) is the average of the product of two ±1 measurement outcomes.
For local hidden variable theories the bound |S| ≤ 2 holds, whereas quantum
mechanics permits values up to 2√2.  A violation (|S|>2) therefore
demonstrates non‑locality.

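Running this script sweeps the measurement angle b over [0, π] (with fixed
a = 0, a' = π/2 and b' = b + π/2), estimates the correlations, computes S,
saves the results to ``results/chsh_values.csv`` and creates a plot
``results/S_vs_angle.png``.  The ``results`` directory is created in the
parent of the directory that contains this script.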
"""

import os
import sys
import math
import numpy as np

# Make a sibling checkout of ``ar-operator-core`` importable without
# installing it.  The candidate location is two directories above the folder
# holding this script (the integrated-workspace layout).  It is prepended to
# ``sys.path`` only when that directory exists and is not already listed.
_core_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..', '..', 'ar-operator-core')
)
if os.path.isdir(_core_path) and _core_path not in sys.path:
    sys.path.insert(0, _core_path)

# Try to import the tick-flip algebra and projector definitions.
#
# - If the package is simply absent (for example a clean environment without
#   ar-operator-core), the script silently falls back to an analytic model of
#   a maximally entangled singlet state.
# - Any other failure while importing ar_sim is reported on stderr before
#   falling back.  This keeps a broken installation from passing as a
#   deliberate analytic run.
try:
    from ar_sim.tick_algebra import TickOperatorSystem
    from ar_sim.projectors import measurement_projector
except ImportError:
    use_ar_sim = False
except Exception as exc:  # deliberate best-effort fallback, but not silent
    print(f"[bell_test] warning: importing ar_sim failed ({exc!r}); "
          "falling back to the analytic singlet model", file=sys.stderr)
    use_ar_sim = False
else:
    use_ar_sim = True

# Simulation objects created lazily on the first Monte Carlo call to
# ``estimate_E``, so that importing this module stays cheap.
_system = None
_psi_entangled = None


def estimate_E(angle_A: float, angle_B: float, trials: int = 10000) -> float:
    """Estimate the correlator E(angle_A, angle_B).

    When ``ar_sim`` imported successfully, this function performs a Monte
    Carlo estimate.  It measures the (lazily created, cached) entangled pair
    with projectors at ``angle_A`` and ``angle_B`` ``trials`` times and
    averages the product of the two outcomes.  Otherwise it returns the
    analytic value for a singlet state, ``-cos(angle_A - angle_B)``.

    Parameters
    ----------
    angle_A : float
        Measurement angle for party A (in radians).
    angle_B : float
        Measurement angle for party B (in radians).
    trials : int, optional
        Number of Monte Carlo samples when using ar_sim; must be >= 1 in that
        mode.  Ignored in analytic mode.

    Returns
    -------
    float
        Estimated correlator E(angle_A, angle_B).

    Raises
    ------
    ValueError
        If ``trials < 1`` while running in ar_sim mode.  Averaging zero
        samples would otherwise produce NaN.
    """
    global _system, _psi_entangled

    if not use_ar_sim:
        # Analytic expectation value for a maximally entangled singlet.
        return -math.cos(angle_A - angle_B)

    if trials < 1:
        raise ValueError(f"trials must be a positive integer, got {trials!r}")

    if _system is None:
        _system = TickOperatorSystem()
        _psi_entangled = _system.create_maximally_entangled_pair()

    P_A = measurement_projector(angle=angle_A)
    P_B = measurement_projector(angle=angle_B)

    # NOTE(review): every trial re-measures the same cached ``_psi_entangled``
    # object.  This is a valid sampler only if ``_system.measure`` leaves the
    # state unchanged (or re-prepares it) -- confirm against ar_sim.
    # Likewise confirm that ``measurement_projector`` (which takes no party or
    # subsystem argument here) acts on A's and B's halves respectively.
    # Outcomes are presumably +/-1, per the module docstring.
    products = []
    for _ in range(trials):
        outcome_A = _system.measure(_psi_entangled, P_A)
        outcome_B = _system.measure(_psi_entangled, P_B)
        products.append(outcome_A * outcome_B)
    return float(np.mean(products))


def _sweep_chsh(points: int, trials: int) -> list:
    """Compute the CHSH value S for ``points`` settings of b.

    b takes ``points`` evenly spaced values in [0, π] (inclusive).  The other
    settings are fixed at a = 0, a' = π/2 and b' = b + π/2.  Returns one dict
    per b value with keys ``'a'``, ``'ap'``, ``'b'``, ``'bp'`` and ``'S'``.
    """
    # Fixed measurement settings for parties A and A'.
    # - The canonical choice that maximally violates the CHSH inequality uses
    #   angles differing by π/2 between the two settings on each side, so
    #   a = 0 and a' = π/2.
    # - With this choice the optimal pair for B and B' is b = π/4 and
    #   b' = 3π/4.
    # - To visualise the violation over a range of settings, b is swept over
    #   [0, π] with b' = b + π/2.
    a = 0.0
    ap = math.pi / 2

    results = []
    for b in np.linspace(0.0, math.pi, points):
        bp = b + math.pi / 2
        E_ab = estimate_E(a, b, trials=trials)
        E_abp = estimate_E(a, bp, trials=trials)
        E_apb = estimate_E(ap, b, trials=trials)
        E_apbp = estimate_E(ap, bp, trials=trials)
        S_val = E_ab - E_abp + E_apb + E_apbp
        results.append({'a': a, 'ap': ap, 'b': b, 'bp': bp, 'S': S_val})
    return results


def _write_chsh_csv(results: list, csv_path: str) -> None:
    """Write the rows produced by ``_sweep_chsh`` to ``csv_path`` with a header."""
    import csv
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['a', 'ap', 'b', 'bp', 'S'])
        writer.writeheader()
        writer.writerows(results)


def _plot_S(b_values: list, S_values: list, plot_path: str) -> None:
    """Plot S against b together with the LHV and Tsirelson bounds; save to ``plot_path``."""
    import matplotlib.pyplot as plt
    lhv_bound = 2.0
    tsirelson_bound = 2.0 * math.sqrt(2.0)
    fig, ax = plt.subplots()
    try:
        ax.plot(b_values, S_values, marker='o', label='Simulated S')
        # S can be negative: the analytic singlet model gives S = -2√2 at
        # b = π/4 with these settings.  Each bound is therefore drawn at both
        # +value and -value; only the positive line carries a legend label.
        ax.axhline(lhv_bound, color='red', linestyle='--',
                   label='Local hidden variable bound (±2)')
        ax.axhline(-lhv_bound, color='red', linestyle='--')
        ax.axhline(tsirelson_bound, color='green', linestyle='--',
                   label='Tsirelson bound (±2√2)')
        ax.axhline(-tsirelson_bound, color='green', linestyle='--')
        ax.set_xlabel('b (radians)')
        ax.set_ylabel('CHSH S-value')
        ax.set_title("CHSH S vs b for fixed a=0, a'=π/2 and b'=b+π/2")
        ax.legend()
        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def main() -> None:
    """Run the CHSH sweep and write its outputs.

    The sweep works as follows:

    - The measurement angle ``b`` takes ``--points`` evenly spaced values in
      [0, π], with fixed ``a = 0``, ``a' = π/2`` and ``b' = b + π/2``.
    - S is computed for each setting.
    - ``chsh_values.csv`` and ``S_vs_angle.png`` are written into a
      ``results`` directory in the parent of this script's directory.
    - The largest |S| found is printed.

    Command-line options are parsed from ``sys.argv``.  ``--trials`` and
    ``--points`` must both be positive integers; argparse exits with status 2
    otherwise.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Run CHSH Bell-test using tick‑flip algebra or analytic fallback.")
    parser.add_argument('--trials', type=int, default=10000,
                        help='Number of trials for Monte Carlo estimation (only used when ar_sim is available)')
    parser.add_argument('--points', type=int, default=9,
                        help='Number of angle points between 0 and π inclusive for b')
    args = parser.parse_args()
    # Reject non-positive values up front.  Zero points would leave nothing to
    # summarise (max() of an empty sequence), and negative points make
    # np.linspace raise.
    if args.trials < 1:
        parser.error('--trials must be a positive integer')
    if args.points < 1:
        parser.error('--points must be a positive integer')

    results = _sweep_chsh(args.points, args.trials)

    # Output directory: <parent of this script's directory>/results
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    results_dir = os.path.join(repo_root, 'results')
    os.makedirs(results_dir, exist_ok=True)

    _write_chsh_csv(results, os.path.join(results_dir, 'chsh_values.csv'))

    b_values = [row['b'] for row in results]
    S_values = [row['S'] for row in results]
    _plot_S(b_values, S_values, os.path.join(results_dir, 'S_vs_angle.png'))

    # Report the largest magnitude of S to quantify the violation.  The sign
    # of S is not physically important; only its magnitude relative to 2
    # matters.
    max_S_abs = max(abs(s) for s in S_values)
    print(f"[bell_test] Completed CHSH sweep: max |S| = {max_S_abs:.4f}")


# Run the Bell-test sweep only when executed as a script, not when imported.
if __name__ == '__main__':
    main()